###########################################################
# Kerice Doten-Snitker
# 
# Helper functions for getting model results from rstan-based objects
# in order to print pretty tables
#
# based on: https://www.andrewheiss.com/blog/2018/03/08/amelia-broom-huxtable/
# 
# Built under R 3.4.3 
# Platform: x86_64-apple-darwin15.6.0 (64-bit)
###########################################################

# Tidy a fitted model object into a data.frame of coefficient estimates,
# optionally with posterior credible intervals.
#
# x            : fitted model object; `x$coefficients` supplies the point
#                estimates and `x$ses` the values reported as std.error.
# cred.int     : if TRUE, add `cred.low` / `cred.high` columns.
# cred.level   : interval probability passed to posterior_interval();
#                a value of 0 disables the intervals.
# digs         : number of decimal places every numeric column is rounded to.
# mixed.levels : number of trailing rows of posterior_interval()'s result to
#                drop (e.g. sigma terms) so that the remaining rows line up,
#                by position, with the coefficients.
#
# Returns a plain data.frame with columns terms, estimate, std.error and,
# when intervals are requested, cred.low and cred.high.
tidy.stanreg <- function(x, cred.int = FALSE, cred.level = 0.89, digs = 2, mixed.levels = 0) {
  # Estimates and standard errors side by side; the coefficient names
  # (row names) become the `terms` column.
  output <- round(as.data.frame(cbind(x$coefficients, x$ses)), digs) %>%
    magrittr::set_colnames(c("estimate", "std.error")) %>%
    rownames_to_column(var = "terms")

  # `&&` rather than `&`: `if` needs a single TRUE/FALSE. A zero cred.level
  # still switches the intervals off.
  if (cred.int && cred.level != 0) {
    pstr <- posterior_interval(x, prob = cred.level)

    # Intervals are matched to coefficients by position: keep the leading
    # rows and drop the last `mixed.levels` rows. seq_len() avoids the
    # 1:0 == c(1, 0) trap when no rows remain.
    keep <- seq_len(nrow(pstr) - mixed.levels)
    if (length(keep) != nrow(output)) {
      stop("posterior_interval() returned ", nrow(pstr), " rows; dropping ",
           "`mixed.levels` = ", mixed.levels, " leaves ", length(keep),
           ", but the model has ", nrow(output), " coefficients.",
           call. = FALSE)
    }

    output <- output %>%
      mutate(cred.low = round(pstr[keep, 1], digs),
             cred.high = round(pstr[keep, 2], digs)) %>%
      dplyr::select(terms, estimate, std.error, cred.low, cred.high)
  }

  # tidy objects only have a data.frame class, not tbl_df or anything else
  class(output) <- "data.frame"
  output
}

# Tidy the full posterior summary of a fitted model into a data.frame.
#
# fitted      : model object whose `$stanfit` element is summarised
#               (presumably an rstanarm fit; confirm against callers).
# cred.levels : quantile probabilities forwarded to summary() as `probs`.
# digs        : number of decimal places every column is rounded to.
#
# Returns a plain data.frame with one row per summarised parameter. The row
# names of the summary matrix become the first column, `terms`.
tidy.stansummary <- function(fitted, cred.levels = c(0.055, 0.5, 0.945), digs = 2) {
  summary_matrix <- summary(fitted$stanfit, probs = cred.levels)$summary
  rounded <- round(as.data.frame(summary_matrix), digs)
  tidied <- rownames_to_column(rounded, var = "terms")

  # Drop any extra classes so the result is a bare data.frame, as tidy
  # outputs are expected to be.
  class(tidied) <- "data.frame"
  tidied
}

# Print (and invisibly return) a summary of post-warmup sampler diagnostics
# totalled per chain.
#
# fitted       : model object whose `$stanfit` element is passed to
#                get_sampler_params() with warmup excluded.
# params       : sampler column name(s), e.g. "accept_stat__", "stepsize__",
#                "treedepth__", "n_leapfrog__", "divergent__", "energy__".
#                When several are given, all of their values are summed
#                together into a single total per chain.
# summary_stat : function applied to the numeric vector of per-chain totals.
sampler_param_summary <- function(fitted, params, summary_stat){
  sampler_params <- get_sampler_params(fitted$stanfit, inc_warmup = FALSE)
  # vapply() guarantees a plain numeric vector (one total per chain), even
  # when there are no chains, unlike sapply() which would return list().
  stat_by_chain <- vapply(sampler_params,
                          function(chain) sum(chain[, params]),
                          numeric(1))
  print(summary_stat(stat_by_chain))
}

# Print (and invisibly return) a table of K-fold CV estimates, one row per fit.
#
# fits : list of kfold results; every element must hold the scalar fields
#        `elpd_kfold` and `se_elpd_kfold`. When `fits` has unique names they
#        become the row names.
kfold_param_summary <- function(fits){
  # One number per fit for the given field. vapply() errors if a fit lacks the
  # field, instead of producing a shorter column that data.frame() would
  # silently recycle against the other one (misaligning the rows).
  pull_scalar <- function(field) {
    vapply(fits, function(fit) fit[[field]], numeric(1))
  }

  kfold_table <- data.frame(kfold_elpd = pull_scalar("elpd_kfold"),
                            kfold_se = pull_scalar("se_elpd_kfold"))
  print(kfold_table)
}

# Print (and invisibly return) a table of LOO-CV statistics, one row per fit.
#
# fits : list of loo results; every element must hold the scalar fields
#        elpd_loo, se_elpd_loo, p_loo, se_p_loo, looic and se_looic, and may
#        carry a "model_name" attribute used as its row name.
#
# Row names:
#   - every fit labelled  -> the "model_name" attributes;
#   - no fit labelled     -> rows numbered 1, 2, ...;
#   - only some labelled  -> unlabelled rows use their name in `fits`, or
#                            their position when that entry is unnamed.
loo_param_summary <- function(fits){
  # One number per fit for the given field. vapply() errors if a fit lacks the
  # field, instead of producing a shorter column that data.frame() would
  # silently recycle against the others (misaligning the rows).
  pull_scalar <- function(field) {
    vapply(fits, function(fit) fit[[field]], numeric(1))
  }

  model_names <- vapply(fits, function(fit) {
    label <- attr(fit, "model_name")
    if (is.null(label)) NA_character_ else as.character(label)
  }, character(1))

  if (all(is.na(model_names))) {
    # Passing NULL explicitly makes data.frame() number the rows.
    model_names <- NULL
  } else if (anyNA(model_names)) {
    # Without this fill-in, the row names would be shorter than the table
    # and data.frame() would error.
    fallback <- names(fits)
    if (is.null(fallback)) fallback <- rep("", length(fits))
    unnamed <- is.na(fallback) | !nzchar(fallback)
    fallback[unnamed] <- as.character(seq_along(fits))[unnamed]
    unlabelled <- is.na(model_names)
    model_names[unlabelled] <- fallback[unlabelled]
  }

  loo_table <- data.frame(loo_elpd = pull_scalar("elpd_loo"),
                          loo_se = pull_scalar("se_elpd_loo"),
                          loo_p_loo = pull_scalar("p_loo"),
                          loo_p_loo_se = pull_scalar("se_p_loo"),
                          loo_ic = pull_scalar("looic"),
                          loo_ic_se = pull_scalar("se_looic"),
                          row.names = model_names)
  print(loo_table)
}

# Print (and invisibly return) a table of sampler control settings and the
# number of divergent transitions, one row per fit.
#
# fits : list of brms fits; control_params() is read for `adapt_delta` and
#        `stepsize`, and each fit's `$fit` element is passed to
#        rstan::get_num_divergent(). When `fits` has unique names they become
#        the row names.
#
# A control setting that a fit does not report is shown as NA. If it were
# dropped instead, its column would be shorter than the others and
# data.frame() would recycle it, misaligning the rows.
brms_controlparam_summary <- function(fits){
  # Query each fit's control list once and reuse it for every column.
  controls <- lapply(fits, control_params)

  control_column <- function(setting) {
    vapply(controls, function(ctrl) {
      value <- ctrl[[setting]]
      if (is.null(value)) NA_real_ else value
    }, numeric(1))
  }

  param_table <- data.frame(
    adapt_delta = control_column("adapt_delta"),
    stepsize = control_column("stepsize"),
    divergent = vapply(fits, function(fit) rstan::get_num_divergent(fit$fit),
                       numeric(1))
  )
  print(param_table)
}
